{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "name": "Capsule Networks",
   "provenance": [],
   "collapsed_sections": []
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "accelerator": "GPU"
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AYV_dMVDxyc2"
   },
   "source": [
    "[![Github](https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social)](https://github.com/labmlai/annotated_deep_learning_paper_implementations)\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/capsule_networks/mnist.ipynb)                    \n",
    "\n",
    "## Training a Capsule Network to classify MNIST digits\n",
    "\n",
    "This is an experiment to train a Capsule Network to classify MNIST digits using PyTorch."
   ]
  },
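  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For orientation, a Capsule Network represents each entity with a vector whose length is read as the probability that the entity is present. The cell below is a minimal sketch of the squashing non-linearity from the Capsule Networks paper (Sabour et al., 2017), which shrinks a vector to a length in $[0, 1)$ without changing its direction:\n",
    "\n",
    "$$v = \\frac{\\lVert s \\rVert^2}{1 + \\lVert s \\rVert^2} \\frac{s}{\\lVert s \\rVert}$$\n",
    "\n",
    "The function name `squash`, the `eps` stabilizer, and the tensor shape are assumptions made for illustration; this is not necessarily how `labml_nn` implements it."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def squash(s: torch.Tensor, dim: int = -1, eps: float = 1e-8) -> torch.Tensor:\n",
    "    \"\"\"Shrink vectors along `dim` to a length in [0, 1), keeping their direction.\"\"\"\n",
    "    # Squared length of each vector, kept as a size-1 dimension for broadcasting\n",
    "    sq_norm = (s ** 2).sum(dim=dim, keepdim=True)\n",
    "    # Scale factor |s|^2 / (1 + |s|^2) times the unit vector s / |s|\n",
    "    return (sq_norm / (1 + sq_norm)) * s / torch.sqrt(sq_norm + eps)\n",
    "\n",
    "\n",
    "# Hypothetical input: a batch of 2 samples, 10 capsules each, 16 features per capsule\n",
    "capsules = torch.randn(2, 10, 16)\n",
    "lengths = squash(capsules).norm(dim=-1)\n",
    "# Print the shape of the lengths and whether every capsule length is below 1\n",
    "print(lengths.shape, bool((lengths < 1).all()))"
   ],
   "execution_count": null,
   "outputs": []
  },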
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AahG_i2y5tY9"
   },
   "source": [
    "Install the `labml-nn` package"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "ZCzmCrAIVg0L",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "outputId": "7ab15f72-c99f-4097-ecd2-5740ee9ed61c"
   },
   "source": [
    "!pip install labml-nn"
   ],
   "execution_count": 1,
   "outputs": [
    {
     "output_type": "stream",
     "text": [
      "Collecting labml-nn\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/f5/92/c454c38d613449e9cfee59809b83589bfc5463ebcf39a72126c268e31a77/labml_nn-0.4.78-py3-none-any.whl (111kB)\n",
      "\r\u001B[K     |███                             | 10kB 23.8MB/s eta 0:00:01\r\u001B[K     |██████                          | 20kB 27.8MB/s eta 0:00:01\r\u001B[K     |████████▉                       | 30kB 22.4MB/s eta 0:00:01\r\u001B[K     |███████████▉                    | 40kB 18.7MB/s eta 0:00:01\r\u001B[K     |██████████████▊                 | 51kB 17.4MB/s eta 0:00:01\r\u001B[K     |█████████████████▊              | 61kB 13.9MB/s eta 0:00:01\r\u001B[K     |████████████████████▋           | 71kB 14.2MB/s eta 0:00:01\r\u001B[K     |███████████████████████▋        | 81kB 14.1MB/s eta 0:00:01\r\u001B[K     |██████████████████████████▋     | 92kB 14.3MB/s eta 0:00:01\r\u001B[K     |█████████████████████████████▌  | 102kB 14.3MB/s eta 0:00:01\r\u001B[K     |████████████████████████████████| 112kB 14.3MB/s \n",
      "\u001B[?25hCollecting einops\n",
      "  Downloading https://files.pythonhosted.org/packages/5d/a0/9935e030634bf60ecd572c775f64ace82ceddf2f504a5fd3902438f07090/einops-0.3.0-py2.py3-none-any.whl\n",
      "Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from labml-nn) (1.7.0+cu101)\n",
      "Collecting labml>=0.4.86\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/87/4c/30c05318a66f4297babef5c3d11a34700ba2c79afc261a0c632eb8225871/labml-0.4.91-py3-none-any.whl (98kB)\n",
      "\u001B[K     |████████████████████████████████| 102kB 11.2MB/s \n",
      "\u001B[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from labml-nn) (1.19.5)\n",
      "Collecting labml-helpers>=0.4.72\n",
      "  Downloading https://files.pythonhosted.org/packages/ec/58/2b7dcfde4565134ad97cdfe96ad7070fef95c37be2cbc066b608c9ae5c1d/labml_helpers-0.4.72-py3-none-any.whl\n",
      "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (0.16.0)\n",
      "Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (0.8)\n",
      "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (3.7.4.3)\n",
      "Requirement already satisfied: pyyaml in /usr/local/lib/python3.6/dist-packages (from labml>=0.4.86->labml-nn) (3.13)\n",
      "Collecting gitpython\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/d7/cb/ec98155c501b68dcb11314c7992cd3df6dce193fd763084338a117967d53/GitPython-3.1.12-py3-none-any.whl (159kB)\n",
      "\u001B[K     |████████████████████████████████| 163kB 46.5MB/s \n",
      "\u001B[?25hCollecting gitdb<5,>=4.0.1\n",
      "\u001B[?25l  Downloading https://files.pythonhosted.org/packages/48/11/d1800bca0a3bae820b84b7d813ad1eff15a48a64caea9c823fc8c1b119e8/gitdb-4.0.5-py3-none-any.whl (63kB)\n",
      "\u001B[K     |████████████████████████████████| 71kB 10.3MB/s \n",
      "\u001B[?25hCollecting smmap<4,>=3.0.1\n",
      "  Downloading https://files.pythonhosted.org/packages/b0/9a/4d409a6234eb940e6a78dfdfc66156e7522262f5f2fecca07dc55915952d/smmap-3.0.4-py2.py3-none-any.whl\n",
      "Installing collected packages: einops, smmap, gitdb, gitpython, labml, labml-helpers, labml-nn\n",
      "Successfully installed einops-0.3.0 gitdb-4.0.5 gitpython-3.1.12 labml-0.4.91 labml-helpers-0.4.72 labml-nn-0.4.78 smmap-3.0.4\n"
     ],
     "name": "stdout"
    }
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SE2VUQ6L5zxI"
   },
   "source": [
    "Imports"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "0hJXx_g0wS2C"
   },
   "source": [
    "import torch\n",
    "\n",
    "from labml import experiment\n",
    "from labml_nn.capsule_networks.mnist import Configs"
   ],
   "execution_count": 1,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Lpggo0wM6qb-"
   },
   "source": [
    "Create an experiment"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "bFcr9k-l4cAg"
   },
   "source": [
    "experiment.create(name=\"capsule_networks\")"
   ],
   "execution_count": 2,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-OnHLi626tJt"
   },
   "source": [
    "Initialize [Capsule Network configurations](https://nn.labml.ai/capsule_networks/mnist.html)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "Piz0c5f44hRo"
   },
   "source": [
    "conf = Configs()"
   ],
   "execution_count": 3,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wwMzCqpD6vkL"
   },
   "source": [
    "Set experiment configurations and assign a configurations dictionary to override configurations"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 17
    },
    "id": "e6hmQhTw4nks",
    "outputId": "ebefa8fa-93d2-4131-db95-e27f15aa3aa0"
   },
   "source": [
    "experiment.configs(conf, {'optimizer.optimizer': 'Adam',\n",
    "                         'optimizer.learning_rate': 1e-3,\n",
    "                         'inner_iterations': 5})"
   ],
   "execution_count": 4,
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "text/html": [
       "<pre style=\"overflow-x: scroll;\"></pre>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {
      "tags": []
     }
    }
   ]
  },
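  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The override keys are dotted paths: `optimizer.learning_rate` addresses the `learning_rate` option nested under `optimizer`. The cell below illustrates that idea with plain dictionaries. It is not labml's implementation; `apply_overrides` and the default values are hypothetical and exist only to make the example runnable."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "from copy import deepcopy\n",
    "\n",
    "\n",
    "def apply_overrides(defaults: dict, overrides: dict) -> dict:\n",
    "    \"\"\"Return a copy of `defaults` with dotted-key overrides applied.\"\"\"\n",
    "    result = deepcopy(defaults)\n",
    "    for dotted_key, value in overrides.items():\n",
    "        # 'optimizer.learning_rate' -> parents ['optimizer'], leaf 'learning_rate'\n",
    "        *parents, leaf = dotted_key.split('.')\n",
    "        node = result\n",
    "        for name in parents:\n",
    "            node = node.setdefault(name, {})\n",
    "        node[leaf] = value\n",
    "    return result\n",
    "\n",
    "\n",
    "# Hypothetical defaults; the real defaults are defined by `Configs`\n",
    "defaults = {'optimizer': {'optimizer': 'SGD', 'learning_rate': 1e-2},\n",
    "            'inner_iterations': 1}\n",
    "overrides = {'optimizer.optimizer': 'Adam',\n",
    "             'optimizer.learning_rate': 1e-3,\n",
    "             'inner_iterations': 5}\n",
    "\n",
    "merged = apply_overrides(defaults, overrides)\n",
    "assert merged['optimizer'] == {'optimizer': 'Adam', 'learning_rate': 1e-3}\n",
    "assert merged['inner_iterations'] == 5\n",
    "assert defaults['inner_iterations'] == 1  # the defaults are left untouched"
   ],
   "execution_count": null,
   "outputs": []
  },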
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EvI7MtgJ61w5"
   },
   "source": [
    "Set PyTorch models for loading and saving"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 102
    },
    "id": "GDlt7dp-5ALt",
    "outputId": "9701092b-c88a-4687-c90e-b193c369e59e"
   },
   "source": [
    "experiment.add_pytorch_models({'model': conf.model})"
   ],
   "execution_count": 5,
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "text/html": [
       "<pre style=\"overflow-x: scroll;\">Prepare model...\n",
       "  Prepare device...\n",
       "    Prepare device_info<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t47.85ms</span>\n",
       "  Prepare device<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t52.83ms</span>\n",
       "Prepare model<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t4,606.99ms</span>\n",
       "</pre>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {
      "tags": []
     }
    }
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KJZRf8527GxL"
   },
   "source": [
    "Start the experiment and run the training loop."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 646
    },
    "id": "aIAWo7Fw5DR8",
    "outputId": "5ddbfce3-91f8-4506-e483-1640cb5a14b3"
   },
   "source": [
    "with experiment.start():\n",
    "    conf.run()"
   ],
   "execution_count": 6,
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "text/html": [
       "<pre style=\"overflow-x: scroll;\">\n",
       "<strong><span style=\"text-decoration: underline\">capsule_networks</span></strong>: <span style=\"color: #208FFB\">e7c08e08586711ebb3e30242ac1c0002</span>\n",
       "\t[dirty]: <strong><span style=\"color: #DDB62B\">\"\"</span></strong>\n",
       "Initialize<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t27.73ms</span>\n",
       "Prepare validator...\n",
       "  Prepare mode<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t6.00ms</span>\n",
       "  Prepare valid_loader...\n",
       "    Prepare valid_dataset...\n",
       "      Prepare dataset_transforms<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t4.04ms</span>\n",
       "    Prepare valid_dataset<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t42.26ms</span>\n",
       "<span style=\"color: #C5C1B4\"></span>\n",
       "<span style=\"color: #C5C1B4\">--------------------------------------------------</span><span style=\"color: #DDB62B\"><strong><span style=\"text-decoration: underline\"></span></strong></span>\n",
       "<span style=\"color: #DDB62B\"><strong><span style=\"text-decoration: underline\">LABML WARNING</span></strong></span>\n",
       "<span style=\"color: #DDB62B\"><strong><span style=\"text-decoration: underline\"></span></strong></span>LabML App Warning: <span style=\"color: #60C6C8\">empty_token: </span><strong>Please create a valid token at https://app.labml.ai.</strong>\n",
       "<strong>Click on the experiment link to monitor the experiment and add it to your experiments list.</strong><span style=\"color: #C5C1B4\"></span>\n",
       "<span style=\"color: #C5C1B4\">--------------------------------------------------</span>\n",
       "<span style=\"color: #208FFB\">Monitor experiment at </span><a href='https://app.labml.ai/run?uuid=e7c08e08586711ebb3e30242ac1c0002' target='blank'>https://app.labml.ai/run?uuid=e7c08e08586711ebb3e30242ac1c0002</a>\n",
       "  Prepare valid_loader<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t127.68ms</span>\n",
       "Prepare validator<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t295.57ms</span>\n",
       "Prepare trainer...\n",
       "  Prepare train_loader...\n",
       "    Prepare train_dataset<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t36.64ms</span>\n",
       "  Prepare train_loader<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t126.53ms</span>\n",
       "Prepare trainer<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t159.96ms</span>\n",
       "Prepare training_loop...\n",
       "  Prepare loop_count<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t34.24ms</span>\n",
       "Prepare training_loop<span style=\"color: #00A250\">...[DONE]</span><span style=\"color: #208FFB\">\t214.47ms</span>\n",
       "<strong><span style=\"color: #DDB62B\">  60,000:  </span></strong>Train:<span style=\"color: #C5C1B4\"> 100%</span><span style=\"color: #208FFB\">    67,954ms  </span>Valid:<span style=\"color: #C5C1B4\"> 100%</span><span style=\"color: #208FFB\"> 7,768ms  </span> loss.train: <span style=\"color: #C5C1B4\">0.036759</span> loss.valid: <span style=\"color: #C5C1B4\">0.018877</span> accuracy.train: <span style=\"color: #C5C1B4\">0.962317</span> accuracy.valid: <span style=\"color: #C5C1B4\">0.979800</span>  <span style=\"color: #208FFB\">78,355ms</span><span style=\"color: #D160C4\">  0:01m/  0:11m  </span>\n",
       "<strong><span style=\"color: #DDB62B\"> 120,000:  </span></strong>Train:<span style=\"color: #C5C1B4\"> 100%</span><span style=\"color: #208FFB\">    67,571ms  </span>Valid:<span style=\"color: #C5C1B4\"> 100%</span><span style=\"color: #208FFB\"> 8,267ms  </span> loss.train: <span style=\"color: #C5C1B4\">0.016659</span> loss.valid: <span style=\"color: #C5C1B4\">0.017786</span> accuracy.train: <span style=\"color: #C5C1B4\">0.989217</span> accuracy.valid: <span style=\"color: #C5C1B4\">0.987000</span>  <span style=\"color: #208FFB\">77,000ms</span><span style=\"color: #D160C4\">  0:02m/  0:10m  </span>\n",
       "<strong><span style=\"color: #DDB62B\"> 180,000:  </span></strong>Train:<span style=\"color: #C5C1B4\"> 100%</span><span style=\"color: #208FFB\">    67,421ms  </span>Valid:<span style=\"color: #C5C1B4\"> 100%</span><span style=\"color: #208FFB\"> 8,458ms  </span> loss.train: <span style=\"color: #C5C1B4\">0.010699</span> loss.valid: <span style=\"color: #C5C1B4\">0.011324</span> accuracy.train: <span style=\"color: #C5C1B4\">0.993017</span> accuracy.valid: <span style=\"color: #C5C1B4\">0.990400</span>  <span style=\"color: #208FFB\">76,496ms</span><span style=\"color: #D160C4\">  0:03m/  0:08m  </span>\n",
       "<strong><span style=\"color: #DDB62B\"> 240,000:  </span></strong>Train:<span style=\"color: #C5C1B4\"> 100%</span><span style=\"color: #208FFB\">    67,333ms  </span>Valid:<span style=\"color: #C5C1B4\"> 100%</span><span style=\"color: #208FFB\"> 8,544ms  </span> loss.train: <span style=\"color: #C5C1B4\">0.001724</span> loss.valid: <span style=\"color: #C5C1B4\">0.010312</span> accuracy.train: <span style=\"color: #C5C1B4\">0.995183</span> accuracy.valid: <span style=\"color: #C5C1B4\">0.992500</span>  <span style=\"color: #208FFB\">76,241ms</span><span style=\"color: #D160C4\">  0:05m/  0:07m  </span>\n",
       "<strong><span style=\"color: #DDB62B\"> 300,000:  </span></strong>Train:<span style=\"color: #C5C1B4\"> 100%</span><span style=\"color: #208FFB\">    67,393ms  </span>Valid:<span style=\"color: #C5C1B4\"> 100%</span><span style=\"color: #208FFB\"> 8,584ms  </span> loss.train: <span style=\"color: #C5C1B4\">0.025503</span> loss.valid: <span style=\"color: #C5C1B4\">0.009328</span> accuracy.train: <span style=\"color: #C5C1B4\">0.996467</span> accuracy.valid: <span style=\"color: #C5C1B4\">0.992300</span>  <span style=\"color: #208FFB\">76,131ms</span><span style=\"color: #D160C4\">  0:06m/  0:06m  </span>\n",
       "<strong><span style=\"color: #DDB62B\"> 360,000:  </span></strong>Train:<span style=\"color: #C5C1B4\"> 100%</span><span style=\"color: #208FFB\">    67,243ms  </span>Valid:<span style=\"color: #C5C1B4\"> 100%</span><span style=\"color: #208FFB\"> 8,614ms  </span> loss.train: <span style=\"color: #C5C1B4\">0.002150</span> loss.valid: <span style=\"color: #C5C1B4\">0.009803</span> accuracy.train: <span style=\"color: #C5C1B4\">0.997183</span> accuracy.valid: <span style=\"color: #C5C1B4\">0.992100</span>  <span style=\"color: #208FFB\">76,030ms</span><span style=\"color: #D160C4\">  0:07m/  0:05m  </span>\n",
       "<strong><span style=\"color: #DDB62B\"> 420,000:  </span></strong>Train:<span style=\"color: #C5C1B4\"> 100%</span><span style=\"color: #208FFB\">    67,368ms  </span>Valid:<span style=\"color: #C5C1B4\"> 100%</span><span style=\"color: #208FFB\"> 8,655ms  </span> loss.train: <span style=\"color: #C5C1B4\">0.000345</span> loss.valid: <span style=\"color: #C5C1B4\">0.011668</span> accuracy.train: <span style=\"color: #C5C1B4\">0.997750</span> accuracy.valid: <span style=\"color: #C5C1B4\">0.992400</span>  <span style=\"color: #208FFB\">75,969ms</span><span style=\"color: #D160C4\">  0:08m/  0:03m  </span>\n",
       "<strong><span style=\"color: #DDB62B\"> 480,000:  </span></strong>Train:<span style=\"color: #C5C1B4\"> 100%</span><span style=\"color: #208FFB\">    67,265ms  </span>Valid:<span style=\"color: #C5C1B4\"> 100%</span><span style=\"color: #208FFB\"> 8,646ms  </span> loss.train: <span style=\"color: #C5C1B4\">0.008524</span> loss.valid: <span style=\"color: #C5C1B4\">0.009893</span> accuracy.train: <span style=\"color: #C5C1B4\">0.998200</span> accuracy.valid: <span style=\"color: #C5C1B4\">0.992000</span>  <span style=\"color: #208FFB\">75,889ms</span><span style=\"color: #D160C4\">  0:10m/  0:02m  </span>\n",
       "<strong><span style=\"color: #DDB62B\"> 540,000:  </span></strong>Train:<span style=\"color: #C5C1B4\"> 100%</span><span style=\"color: #208FFB\">    67,440ms  </span>Valid:<span style=\"color: #C5C1B4\"> 100%</span><span style=\"color: #208FFB\"> 8,660ms  </span> loss.train: <span style=\"color: #C5C1B4\">0.000430</span> loss.valid: <span style=\"color: #C5C1B4\">0.010111</span> accuracy.train: <span style=\"color: #C5C1B4\">0.998383</span> accuracy.valid: <span style=\"color: #C5C1B4\">0.991400</span>  <span style=\"color: #208FFB\">75,870ms</span><span style=\"color: #D160C4\">  0:11m/  0:01m  </span>\n",
       "<strong><span style=\"color: #DDB62B\"> 600,000:  </span></strong> loss.train: <span style=\"color: #C5C1B4\">0.000784</span> loss.valid: <span style=\"color: #C5C1B4\">0.009602</span> accuracy.train: <span style=\"color: #C5C1B4\">0.998817</span> accuracy.valid: <span style=\"color: #C5C1B4\">0.992500</span>\n",
       "<strong><span style=\"color: #DDB62B\">Still updating LabML App, please wait for it to complete...</span></strong></pre>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {
      "tags": []
     }
    }
   ]
  }
 ]
}